Chapter 5: Datasets and models

Read datasets and models

library("DALEX")
library("randomForest")
library("patchwork")
library("ggplot2")
# Seed the RNG before fitting so the random forest below is reproducible.
set.seed(1313)

# Random forest for `survived` on the seven passenger characteristics.
titanic_rf <- randomForest(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed
)

# Logistic regression on the same predictors, fitted with lrm() after
# attaching rms; age enters the formula through rcs().
library("rms")
titanic_lrm <- lrm(
  survived ~ class + gender + rcs(age) + sibsp + parch + fare + embarked,
  data = titanic_imputed
)

# Henry: a single-row data frame describing a 47-year-old male first-class
# passenger who embarked in Cherbourg. Factor levels are spelled out in full
# so the row carries the complete level sets listed here.
henry <- data.frame(
  class = factor(
    "1st",
    levels = c(
      "1st", "2nd", "3rd", "deck crew",
      "engineering crew", "restaurant staff", "victualling crew"
    )
  ),
  gender = factor("male", levels = c("female", "male")),
  age = 47,
  sibsp = 0,
  parch = 0,
  fare = 25,
  embarked = factor(
    "Cherbourg",
    levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton")
  )
)
henry

# Johnny D: a single-row data frame describing an 8-year-old male first-class
# passenger who embarked in Southampton, with the same columns and factor
# levels as henry.
johnny_d <- data.frame(
  class = factor(
    "1st",
    levels = c(
      "1st", "2nd", "3rd", "deck crew",
      "engineering crew", "restaurant staff", "victualling crew"
    )
  ),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor(
    "Southampton",
    levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton")
  )
)
johnny_d
# Wrap the logistic regression in a DALEX explainer.
# The explainer data holds the predictors only: the target `survived` is
# dropped by name. Selecting by position with `[, -9]` is a silent no-op on
# the 8-column titanic_imputed and leaves the target inside the explainer data.
titanic_lrm_exp <- DALEX::explain(
  model = titanic_lrm,
  data = titanic_imputed[, colnames(titanic_imputed) != "survived"],
  y = titanic_imputed$survived,
  label = "Logistic Regression"
)
## Preparation of a new explainer is initiated
##   -> model label       :  Logistic Regression 
##   -> data              :  2207  rows  7  cols 
##   -> target variable   :  2207  values 
##   -> predict function  :  yhat.lrm  will be used (  default  )
##   -> predicted values  :  numerical, min =  0.002671631 , mean =  0.3221568 , max =  0.9845724  
##   -> model_info        :  package rms , ver. 6.0.0 , task classification (  default  ) 
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.9845724 , mean =  -2.491758e-09 , max =  0.9715125  
##   A new explainer has been created! 

# Same wrapper for the random forest, again with the target column removed
# from the explainer data by name.
titanic_rf_exp <- DALEX::explain(
  model = titanic_rf,
  data = titanic_imputed[, colnames(titanic_imputed) != "survived"],
  y = titanic_imputed$survived,
  label = "Random Forest"
)
## Preparation of a new explainer is initiated
##   -> model label       :  Random Forest 
##   -> data              :  2207  rows  7  cols 
##   -> target variable   :  2207  values 
##   -> predict function  :  yhat.randomForest  will be used (  default  )
##   -> predicted values  :  numerical, min =  0.01590278 , mean =  0.3222722 , max =  0.9900173  
##   -> model_info        :  package randomForest , ver. 4.6.14 , task regression (  default  ) 
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.7970723 , mean =  -0.0001153935 , max =  0.8992474  
##   A new explainer has been created! 

# The explainer logs "task regression" for the random forest by default
# (see the model_info line above); mark it as a classification task instead.
titanic_rf_exp$model_info$type <- "classification"

Chapter 7: Break-down Plots for Additive Attributions

Examples

# Break-down attributions for johnny_d under the random-forest explainer,
# using a fixed variable order. keep_distributions = TRUE is set so the
# result can later be drawn with plot_distributions = TRUE.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  keep_distributions = TRUE
)
bd_rf

Plot the break down plots

# Default plot of the break-down explanation stored in bd_rf.
plot(bd_rf) 

# Same explanation, drawn with plot_distributions = TRUE.
plot(bd_rf, plot_distributions = TRUE) 

Basic use of the predict_parts() function

# Break-down attributions for henry with every optional argument left at
# its default.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down"
)
bd_rf

Plot the break down plots

# Plot the break-down explanation currently stored in bd_rf.
plot(bd_rf) 

Advanced use of the predict_parts() function

# Break-down for henry with an explicit variable order; the plot call is
# given max_features = 3.
bd_rf_order <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked")
)
plot(bd_rf_order, max_features = 3)

# Same order, but with keep_distributions = TRUE so the plot can be drawn
# with plot_distributions = TRUE.
bd_rf_distr <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  keep_distributions = TRUE
)
plot(bd_rf_distr, plot_distributions = TRUE)

Chapter 8: Break-down Plots for Interactions (iBreak-down Plots)

Examples

# Break-down with interactions for johnny_d under the random-forest
# explainer; print the result, then plot it.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down_interactions"
)
bd_rf
plot(bd_rf)

Code snippets for R

# Break-down with interactions for henry under the random-forest explainer;
# print the result, then plot it.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down_interactions"
)
bd_rf
plot(bd_rf)

Chapter 9: Shapley Additive Explanations (SHAP) and Average Variable Attributions

# Seed the RNG so the random variable orderings below are reproducible.
set.seed(13)

# Build ten break-down plots for johnny_d, each using a freshly shuffled
# order of the seven variables, all drawn on the same y-axis range.
rsample <- lapply(seq_len(10), function(idx) {
  shuffled <- sample(1:7)
  parts <- predict_parts(titanic_rf_exp, johnny_d, order = shuffled)
  parts$variable <- as.character(parts$variable)
  # Shorten the port label.
  is_port <- parts$variable == "embarked = Southampton"
  parts$variable[is_port] <- "embarked = S"
  parts$label <- paste("random order no.", idx)
  plot(parts) +
    scale_y_continuous(
      limits = c(0.1, 0.6),
      name = "",
      breaks = seq(0.1, 0.6, 0.1)
    )
})

# Add the ten plots together left to right (the same fold as writing
# rsample[[1]] + ... + rsample[[10]]) and lay them out in two columns.
Reduce(`+`, rsample) + plot_layout(ncol = 2)

# SHAP attributions for johnny_d with B = 25
# (presumably the number of random orderings; confirm in the predict_parts docs).
shap_johnny <- predict_parts(
  titanic_rf_exp,
  new_observation = johnny_d,
  B = 25,
  type = "shap"
)

Example: Titanic data

Code snippets for R

# Random-forest prediction for henry; the recorded output follows.
predict(titanic_rf_exp, henry)
##         1 
## 0.3081968

# SHAP attributions for henry with B = 25.
shap_henry <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "shap",
  B = 25
)
shap_henry

# Plot with the default settings, then again with show_boxplots = FALSE.
plot(shap_henry)

plot(shap_henry, show_boxplots = FALSE)

Chapter 10: Local Interpretable Model-agnostic Explanations (LIME)

The lime package

# Seed the RNG before computing the surrogate explanation.
set.seed(1)
library("lime")
library("localModel")

# Surrogate explanation for johnny_d requested with type = "lime",
# using 3 features and 1000 permutations.
lime_johnny <- predict_surrogate(
  titanic_rf_exp,
  new_observation = johnny_d,
  n_features = 3,
  n_permutations = 1000,
  type = "lime"
)
as.data.frame(lime_johnny)
plot(lime_johnny)

The localModel package

library("localModel")

# Surrogate explanation for johnny_d requested with type = "localModel";
# size and seed are passed as 1000 and 1. Only the first three columns of
# the result are shown.
lime_johnny <- predict_surrogate(
  titanic_rf_exp,
  new_observation = johnny_d,
  size = 1000,
  seed = 1,
  type = "localModel"
)
lime_johnny[, 1:3]

The iml package

library("iml")
library("localModel")

# Surrogate explanation for johnny_d requested with type = "iml" and k = 3;
# show its `results` element, then plot it.
lime_johnny <- predict_surrogate(
  titanic_rf_exp,
  new_observation = johnny_d,
  k = 3,
  type = "iml"
)
lime_johnny$results
plot(lime_johnny)

Chapter 11: Ceteris-paribus Profiles

Basic use of the predict_profile() function

library("DALEX")

# Ceteris-paribus profiles for johnny_d under the random-forest explainer.
cp_titanic_rf <- predict_profile(
  explainer = titanic_rf_exp,
  new_observation = johnny_d
)

cp_titanic_rf
library("ggplot2")

# Profiles for age and fare.
plot(cp_titanic_rf, variables = c("age", "fare"))

# Profiles for class and embarked, drawn with variable_type = "categorical".
plot(
  cp_titanic_rf,
  variables = c("class", "embarked"),
  variable_type = "categorical"
)

Advanced use of the predict_profile() function

# Explicit grids of points for the two continuous variables, in steps of 0.1.
variable_splits <- list(age = seq(0, 70, 0.1), fare = seq(0, 100, 0.1))

# Ceteris-paribus profiles for henry, evaluated on the grids above.
cp_titanic_rf <- predict_profile(
  explainer = titanic_rf_exp,
  new_observation = henry,
  variable_splits = variable_splits
)
plot(cp_titanic_rf, variables = c("age", "fare")) +
  ylim(0, 1) +
  ggtitle(
    "Ceteris-paribus Profile",
    "For the random-forest model, Titanic dataset, and Henry"
  )

# Profiles for both passengers at once, coloured by the "_ids_" column.
cp_titanic_rf2 <- predict_profile(
  explainer = titanic_rf_exp,
  new_observation = rbind(henry, johnny_d),
  variable_splits = variable_splits
)
library(ingredients)
# Legend labels follow the order of rbind(henry, johnny_d); the second label
# is spelled to match the johnny_d object.
plot(cp_titanic_rf2, color = "_ids_", variables = c("age", "fare")) +
  scale_color_manual(
    name = "Passenger:",
    breaks = 1:2,
    values = c("#4378bf", "#8bdcbe"),
    labels = c("henry", "johnny_d")
  ) +
  ggtitle(
    "Ceteris-paribus Profile",
    "For the random-forest model, Titanic data, and Henry and Johnny D"
  )

Comparison of models (challenger-champion analysis)

# Ceteris-paribus profiles for henry from both explainers, evaluated on the
# same variable_splits grid so the curves can be overlaid.
cp_titanic_rf <- predict_profile(
  explainer = titanic_rf_exp,
  new_observation = henry,
  variable_splits = variable_splits
)
cp_titanic_lmr <- predict_profile(
  explainer = titanic_lrm_exp,
  new_observation = henry,
  variable_splits = variable_splits
)

# Overlay the two sets of profiles, coloured by the "_label_" column.
plot(
  cp_titanic_rf,
  cp_titanic_lmr,
  color = "_label_",
  variables = c("age", "fare")
) +
  ggtitle("Ceteris-paribus Profiles for Henry")

Chapter 12: Ceteris-paribus Oscillations

Examples

# Oscillations for henry requested with type = "oscillations_uni".
oscillations_equi <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "oscillations_uni"
)

# Oscillations for henry requested with type = "oscillations_emp" and
# N = 1000.
oscillations_emp <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "oscillations_emp",
  N = 1000
)

# Overwrite the "_ids_" column of both results with the passenger's name.
oscillations_equi$`_ids_` <- "Henry"
oscillations_emp$`_ids_` <- "Henry"

# One titled plot per result, then both placed side by side.
pl1 <- plot(oscillations_equi) +
  ggtitle("CP Oscillations for uniform distribution", "")
pl2 <- plot(oscillations_emp) +
  ggtitle("CP Oscillations for empirical distribution", "")

pl1 + pl2

Basic use of the predict_parts() function

# Oscillations for henry requested with type = "oscillations_uni"; the
# result is printed before its "_ids_" column is overwritten.
oscillations_uniform <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "oscillations_uni"
)
oscillations_uniform

# Overwrite the "_ids_" column with the passenger's name, then plot.
oscillations_uniform$`_ids_` <- "Henry"
plot(oscillations_uniform) +
  ggtitle(
    "Ceteris-paribus Oscillations",
    "Expectation over uniform distribution (unique values)"
  )

Advanced use of the predict_parts() function

# Oscillations for henry computed over a user-specified grid: numeric
# sequences for age and fare, and the unique values found in titanic_imputed
# for gender and class.
oscillations_equidist <- predict_parts(
  titanic_rf_exp,
  henry,
  variable_splits = list(
    age = seq(0, 65, 0.1),
    fare = seq(0, 200, 0.1),
    gender = unique(titanic_imputed$gender),
    class = unique(titanic_imputed$class)
  ),
  type = "oscillations"
)
oscillations_equidist

# Overwrite the "_ids_" column with the passenger's name, then plot.
oscillations_equidist$`_ids_` <- "Henry"
plot(oscillations_equidist) +
  ggtitle(
    "Ceteris-paribus Oscillations",
    "Expectation over specified grid of points"
  )

Session info

# Print the R version, platform and package versions used for this run;
# the recorded output follows.
sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.3
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] ingredients_1.3.1   gower_0.2.2         glmnet_4.0-2       
##  [4] Matrix_1.2-18       iml_0.10.0          localModel_0.5     
##  [7] lime_0.5.1          rms_6.0-0           SparseM_1.78       
## [10] Hmisc_4.4-0         Formula_1.2-3       survival_3.2-3     
## [13] lattice_0.20-41     ggplot2_3.3.2       patchwork_1.0.1    
## [16] randomForest_4.6-14 DALEX_1.3.1        
## 
## loaded via a namespace (and not attached):
##  [1] jsonlite_1.7.0      splines_4.0.2       foreach_1.5.0      
##  [4] shiny_1.5.0         assertthat_0.2.1    latticeExtra_0.6-29
##  [7] yaml_2.2.1          pillar_1.4.6        backports_1.1.8    
## [10] quantreg_5.55       glue_1.4.1          digest_0.6.25      
## [13] promises_1.1.1      RColorBrewer_1.1-2  checkmate_2.0.0    
## [16] colorspace_1.4-1    sandwich_2.5-1      httpuv_1.5.4       
## [19] htmltools_0.5.0     pkgconfig_2.0.3     xtable_1.8-4       
## [22] purrr_0.3.4         mvtnorm_1.1-1       scales_1.1.1       
## [25] later_1.1.0.1       jpeg_0.1-8.1        MatrixModels_0.4-1 
## [28] htmlTable_2.0.1     tibble_3.0.3        generics_0.0.2     
## [31] farver_2.0.3        ellipsis_0.3.1      TH.data_1.0-10     
## [34] withr_2.2.0         nnet_7.3-14         mime_0.9           
## [37] magrittr_1.5        crayon_1.3.4        polspline_1.1.19   
## [40] evaluate_0.14       nlme_3.1-148        MASS_7.3-51.6      
## [43] foreign_0.8-80      tools_4.0.2         data.table_1.12.8  
## [46] lifecycle_0.2.0     multcomp_1.4-13     stringr_1.4.0      
## [49] munsell_0.5.0       prediction_0.3.14   cluster_2.1.0      
## [52] compiler_4.0.2      inum_1.0-1          rlang_0.4.7        
## [55] grid_4.0.2          iterators_1.0.12    rstudioapi_0.11    
## [58] htmlwidgets_1.5.1   base64enc_0.1-3     labeling_0.3       
## [61] rmarkdown_2.3       partykit_1.2-9      gtable_0.3.0       
## [64] codetools_0.2-16    Metrics_0.1.4       R6_2.4.1           
## [67] gridExtra_2.3       zoo_1.8-8           knitr_1.29         
## [70] dplyr_1.0.0         fastmap_1.0.1       libcoin_1.0-5      
## [73] shinythemes_1.1.2   iBreakDown_1.3.0    shape_1.4.4        
## [76] stringi_1.4.6       Rcpp_1.0.5          vctrs_0.3.2        
## [79] rpart_4.1-15        acepack_1.4.1       png_0.1-7          
## [82] tidyselect_1.1.0    xfun_0.15